from .base_codec import BaseCodec

class IdentityCodec(BaseCodec):
    """Pass-through codec that returns its input tokens as the tokens.

    Optionally, a threshold can be configured: ``get_tokens`` then also
    reports which entries are strictly greater than ``mask_thr`` and, if
    ``mask_value`` is set as well, overwrites every entry that is less than
    or equal to ``mask_thr`` with ``mask_value``.

    Args:
        mask_thr: Threshold compared against each token, or ``None`` to
            disable masking entirely.
        mask_value: Replacement written to entries ``<= mask_thr``, or
            ``None`` to leave the token values untouched. Ignored when
            ``mask_thr`` is ``None``.
    """

    def __init__(self, mask_thr=None, mask_value=None):
        super().__init__()
        self.trainable = False
        self.mask_thr = mask_thr
        self.mask_value = mask_value
        # NOTE(review): _set_trainable comes from BaseCodec (not visible
        # here); presumably it applies the `trainable` flag set above.
        self._set_trainable()

    def get_tokens(self, x):
        """Return the tokens for ``x`` and, if configured, a threshold mask.

        Args:
            x: 2-dimensional token tensor of shape ``B x L``.

        Returns:
            dict with:
              * ``'token'``: ``x`` itself, or -- when both ``mask_thr`` and
                ``mask_value`` are set -- a copy of ``x`` in which entries
                ``<= mask_thr`` are replaced by ``mask_value``.
              * ``'mask'`` (only when ``mask_thr`` is set): boolean tensor,
                ``True`` where the *original* ``x > mask_thr``.

        Raises:
            AssertionError: if ``x`` is not 2-dimensional.
        """
        assert x.dim() == 2, 'input token should be 2 dimensional with shape: B x L'
        out = {}
        if self.mask_thr is not None:
            # Computed before any replacement, so it reflects the raw input.
            mask = x > self.mask_thr
            if self.mask_value is not None:
                below_or_equal = x <= self.mask_thr
                # Work on a copy: writing into `x` directly would silently
                # modify the caller's tensor as a side effect of a getter.
                x = x.clone()
                x[below_or_equal] = self.mask_value
            out['mask'] = mask
        out['token'] = x

        return out